import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from cat_rfs.code.utils.tools import keep_perils
from cat_rfs.code.simulate.simulate import simulate_mkt
from cat_rfs.code.utils.graph import convert_to_percentages

# Figure 5 variants, selected through the ``setting`` argument of ``print_fig5``.
#   yield_var -- column renamed to ``yield`` and used as the bond yield
#                (``act_yield_usd`` additionally restricts dates to 2008-2018)
#   wsst      -- if True, the ``*_wsst`` risk-measure columns replace the
#                plain ``attachment_prob`` / ``expected_loss`` / ``exhaustion_prob``
#   fig_file  -- stem of the EPS files written when output == 'paper'
SETTINGS = {
    'main': dict(yield_var='sheet_yield_usd', wsst=False, fig_file='fig5'),
    'trace': dict(yield_var='act_yield_usd', wsst=False, fig_file='fig5_trace'),
    'wsst': dict(yield_var='sheet_yield_usd', wsst=True, fig_file='fig5_wsst'),
}


# July-to-June windows shaded on both panels of Figure 5; the first span carries
# the 'Qualifying disaster(s)' legend entry. Defined once so both panels agree.
_DISASTER_WINDOWS = [
    (pd.to_datetime('7/1/2005'), pd.to_datetime('6/30/2006')),  # Katrina
    (pd.to_datetime('7/1/2008'), pd.to_datetime('6/30/2009')),  # Ike, Lehman
    (pd.to_datetime('7/1/2010'), pd.to_datetime('6/30/2011')),  # Tohoku
    # NOTE(review): originally also tagged "Patricia", which duplicates the
    # 2015 entry -- TODO confirm which event qualifies this window.
    (pd.to_datetime('7/1/2012'), pd.to_datetime('6/30/2013')),
    (pd.to_datetime('7/1/2015'), pd.to_datetime('6/30/2016')),  # Patricia
    (pd.to_datetime('7/1/2017'), pd.to_datetime('6/30/2018')),  # HIM
]

# ``perilss`` value -> name of the aggregated expected-return column.
_PORTFOLIO_COLUMNS = {'Cyclone': 'er_cyclone', 'Earthquake': 'er_earthquake'}


def _prepare_bonds(df, cfg):
    """Return a filtered copy of *df* with the configured yield renamed to ``yield``.

    ``perilss`` receives the ``perils`` values as they stand after
    ``keep_perils`` runs, while ``perils`` itself is restored to the incoming
    values. Rows with a missing yield are dropped.
    """
    # Copy first so the caller's frame does not gain helper columns.
    df = df.copy()
    df['perilsc'] = df['perils'].copy()
    df = keep_perils(df)
    df['perilss'] = df['perils'].copy()
    df['perils'] = df['perilsc'].copy()

    # Keeps Q2 (Apr-Jun) dates; original note said "Jun only", presumably the
    # data are quarter-end observations -- confirm.
    df = df[df['date'].dt.quarter == 2].copy()

    if cfg['wsst']:
        for col in ('attachment_prob', 'expected_loss', 'exhaustion_prob'):
            df = df.drop(columns=[col]).rename(columns={col + '_wsst': col})

    if cfg['yield_var'] == 'act_yield_usd':
        df = df[(df['date'] >= '12/31/2008') & (df['date'] <= '12/31/2018')]

    df = df.rename(columns={cfg['yield_var']: 'yield'})
    return df.dropna(subset=['yield'])


def _premium_by_portfolio(df):
    """Size-weighted mean excess yield per date for the cyclone and earthquake bonds.

    A bond's excess yield is ``yield - expected_loss / 100 - rf / 100``. It is
    averaged within each (date, ``perilss``) group with ``size`` as the weight.
    Returns a frame with columns ``date``, ``er_cyclone``, ``er_earthquake``.
    """
    bonds = df[df['perilss'].isin(list(_PORTFOLIO_COLUMNS))].copy()
    bonds['er'] = (bonds['yield'] - bonds['expected_loss'] / 100 - bonds['rf'] / 100) * bonds['size']
    totals = bonds.groupby(['date', 'perilss'])[['er', 'size']].sum()
    premium = (totals['er'] / totals['size']).unstack('perilss')
    # Name columns by peril rather than by position, so a peril with no bonds
    # yields an all-NaN column instead of a column-count mismatch.
    premium = premium.reindex(columns=list(_PORTFOLIO_COLUMNS)).rename(columns=_PORTFOLIO_COLUMNS)
    premium.columns.name = None
    return premium.reset_index()


def _attach_variances(df, premium, reconstruct_variances):
    """Merge per-portfolio variances (``var_m_cyclone`` / ``var_m_earthquake``) onto *premium*.

    When *reconstruct_variances* is True they come from ``simulate_mkt`` run
    separately on the cyclone and earthquake bonds of *df*. Otherwise they are
    read from ``cat_rfs/data/ts.csv``. Both paths inner-join on ``date``.
    """
    if reconstruct_variances:
        dfh = simulate_mkt(df[df['perilss'] == 'Cyclone'].copy(), 500000, date_list=list(df['date'].unique()),
                           var='return', lambda_='actual', ew=False)
        dfe = simulate_mkt(df[df['perilss'] == 'Earthquake'].copy(), 500000, date_list=list(df['date'].unique()),
                           var='return', lambda_='actual', ew=False)
        merged = premium.merge(dfh, on='date')
        # The suffixes disambiguate the overlapping simulate_mkt columns into
        # *_cyclone / *_earthquake.
        return merged.merge(dfe, on='date', suffixes=['_cyclone', '_earthquake'])

    variances = pd.read_csv('cat_rfs/data/ts.csv', parse_dates=['date']).rename(
        columns={'var_m_sheet_cyclone': 'var_m_cyclone', 'var_m_sheet_earthquake': 'var_m_earthquake'})
    return premium.merge(variances, on='date')


def _plot_panel(frame, ylabel):
    """Plot the two portfolio series of *frame* with disaster windows shaded; return the axes."""
    ax = frame.plot(style=['k-', 'k--'], figsize=(6.5, 3.5))

    first, *rest = _DISASTER_WINDOWS
    # Only the first span is labelled so the legend shows a single entry.
    ax.axvspan(first[0], first[1], color='lightgray', label='Qualifying disaster(s)')
    for start, end in rest:
        ax.axvspan(start, end, color='lightgray')

    ax.legend(loc='upper right')
    ax.set_ylim(bottom=0)
    ax.set_xlim(frame.index.min(), frame.index.max())
    ax.set_ylabel(ylabel)
    ax.set_xlabel('')
    return ax


def print_fig5(df, output='console', setting='main', reconstruct_variances=False):
    """
    Print Figure 5: Time series evolution of price of cyclone and earthquake risks.

    Panel (a) plots the size-weighted expected excess return divided by the
    portfolio variance. Panel (b) plots the expected excess return alone.
    Both series are resampled to daily frequency and linearly interpolated.

    Parameters
    ----------
    df : DataFrame
        Bond-level data. Columns read: ``perils``, ``date`` (datetime),
        the yield column named in ``SETTINGS[setting]``, ``expected_loss``,
        ``rf`` and ``size``, plus the ``*_wsst`` risk measures for the
        'wsst' setting. The frame is not modified.
    output : str
        ``'paper'`` writes ``cat_rfs/output/figures/<fig_file>_a.eps`` and
        ``_b.eps``. Any other value shows the figures interactively.
    setting : str
        Key of ``SETTINGS``. An unknown key raises ``KeyError``.
    reconstruct_variances : bool
        Recompute variances with ``simulate_mkt`` instead of reading
        ``cat_rfs/data/ts.csv``.
    """
    cfg = SETTINGS[setting]
    fname = cfg['fig_file']

    bonds = _prepare_bonds(df, cfg)
    df2 = _attach_variances(bonds, _premium_by_portfolio(bonds), reconstruct_variances)

    # Shared by both panels (previously computed twice).
    daily = df2.set_index('date').resample('D').interpolate(method='linear')

    # %% ER / VAR
    ratio = pd.DataFrame({
        'Cyclone portfolio': daily['er_cyclone'] / daily['var_m_cyclone'],
        'Earthquake portfolio': daily['er_earthquake'] / daily['var_m_earthquake'],
    }, index=daily.index)
    _plot_panel(
        ratio,
        'Observed premium $\\left(\\frac{E_t\\left(R^e_{p}\\right)}{\\mathrm{Var}_t\\left(R^e_{p}\\right)}\\right)$')

    if output == 'paper':
        plt.savefig(f'cat_rfs/output/figures/{fname}_a.eps')
        plt.close()
    else:
        plt.show()

    # %% ER
    sns.set_palette('Greys_r')
    expected = daily[['er_cyclone', 'er_earthquake']].rename(
        columns={'er_cyclone': 'Cyclone portfolio', 'er_earthquake': 'Earthquake portfolio'})
    ax = _plot_panel(expected, 'Expected return ($E_t\\left(R^e_{p}\\right)$)')
    convert_to_percentages(ax)

    if output == 'paper':
        plt.savefig(f'cat_rfs/output/figures/{fname}_b.eps')
        plt.close()
    else:
        print('Execution completed. Close figures to exit.')
        plt.show(block=True)
